#!/usr/bin/env python3
"""
state_analysis.py

Analyze Ollama Llama3 outputs across:
- hardware: CPU vs GPU
- regime/condition: BaselineNoJSON, BaselineWithJSON, HighRigor, PushbackStrong
- lifecycle/state groups:
    * Cold: C1  (first run after a service restart; set with --cold_tag)
    * Warm: W*  (default tags W2 W3 W4; set with --warm_tags)
    * AltEnv: A* (default tags A1 A2 A3, run from a different terminal environment; set with --alt_tags)

It computes, per condition and per hardware:
- Within-group determinism (exact match across the group's runs, per item 1..50)
- Cold vs Warm divergence (C1 vs warm consensus, per item)
- Alt vs Warm divergence (A-consensus vs warm consensus, per item)
- Cold vs Alt divergence (C1 vs A-consensus, per item)
- Mean wordcount per item per group (Warm, Cold, Alt)
- Paired stats on wordcount differences between groups (p-values only when SciPy is installed)

It also computes, per condition, CPU vs GPU comparisons for each group (Warm primary by default):
- Exact match count (consensus vs consensus)
- Wordcount mean difference and paired tests (CPU-GPU)

Usage:
  python state_analysis.py --base_dir "/path/to/logs" --out_prefix results --warm_tags W2 W3 W4 --alt_tags A1 A2 A3

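Outputs (written into --base_dir; names start with --out_prefix, shown here with the default "results"):
  results_summary.csv
  results_itemwise_divergence.csv  (per-item 0/1 match flags for each comparison, per hardware)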
"""

import argparse
import csv
import math
import re
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# Optional stats
try:
    from scipy.stats import ttest_rel, wilcoxon
    SCIPY_OK = True
except Exception:
    SCIPY_OK = False

# Line-leading item header such as "7. " (leading whitespace allowed). The item number is
# a capturing group, so ITEM_SPLIT_RE.split() yields [preamble, num, body, num, body, ...].
ITEM_SPLIT_RE = re.compile(r"(?m)^\s*(\d{1,2})\.\s")

# Prompt regimes and hardware back-ends that are analysed; build_filename() combines
# one of each with a run tag to name a log file.
CONDITIONS = [
    "BaselineNoJSON",
    "BaselineWithJSON",
    "HighRigor",
    "PushbackStrong",
]
HARDWARES = [
    "CPU",
    "GPU",
]

def trim_item50_noise(text: str) -> str:
    """
    Strip a trailing "Please note ..." footer from the text of item 50.

    The footer is matched case-insensitively when it is preceded by an empty line
    (two consecutive newlines, LF or CRLF). Everything from that point on is dropped.
    The result is right-stripped whether or not a footer was found.
    """
    footer = re.search(r"(?ims)\r?\n\r?\n+please note\b", text)
    cut_at = footer.start() if footer is not None else len(text)
    return text[:cut_at].rstrip()

def extract_items(text: str) -> Dict[int, str]:
    """
    Split one run's output into its numbered answers.

    The text is split on line-leading "N. " headers (ITEM_SPLIT_RE). Any preamble before
    the first header is discarded, each item's content is stripped, and item 50 also has
    its known footer noise removed (trim_item50_noise).

    Returns:
        A dict mapping item number -> answer text, containing exactly items 1..50.

    Raises:
        ValueError: if an item number occurs more than once, or if the item numbers are
            not exactly 1..50.
    """
    splits = ITEM_SPLIT_RE.split(text)
    items: Dict[int, str] = {}
    duplicates: List[int] = []
    # With one capturing group, re.split alternates [preamble, num, body, num, body, ...].
    for i in range(1, len(splits), 2):
        num = int(splits[i])
        if num in items:
            # Previously a repeat silently overwrote the earlier item. If the key set still
            # equalled 1..50, the corrupted data passed validation unnoticed.
            duplicates.append(num)
        items[num] = splits[i + 1].strip()
    if duplicates:
        raise ValueError(
            f"Duplicate item numbers: {sorted(set(duplicates))} "
            "(an answer may contain its own line-leading numbered list)"
        )
    if 50 in items:
        items[50] = trim_item50_noise(items[50])
    expected = set(range(1, 51))
    if set(items.keys()) != expected:
        raise ValueError(f"Bad item keys. Missing: {sorted(expected - set(items.keys()))} Extra: {sorted(set(items.keys()) - expected)}")
    return items

def normalize(s: str) -> str:
    """Collapse every run of whitespace in *s* to a single space and trim both ends."""
    return " ".join(s.split())

def word_count(s: str) -> int:
    """Return the number of whitespace-separated tokens in *s*."""
    tokens = s.split()
    return len(tokens)

def mean(xs: List[float]) -> float:
    """Arithmetic mean of *xs*; NaN for an empty list."""
    if not xs:
        return float("nan")
    return sum(xs) / len(xs)

def stdev(xs: List[float]) -> float:
    """Sample standard deviation of *xs* (n - 1 denominator); NaN when fewer than two values."""
    count = len(xs)
    if count < 2:
        return float("nan")
    centre = mean(xs)
    squared_dev = sum((value - centre) ** 2 for value in xs)
    return math.sqrt(squared_dev / (count - 1))

def cohens_dz(diffs: List[float]) -> float:
    """
    Cohen's d_z for paired differences: mean(diffs) / sample SD(diffs).

    Returns NaN when *diffs* is empty or its SD is zero or NaN.
    """
    spread = stdev(diffs)
    undefined = (not diffs) or spread == 0 or math.isnan(spread)
    return float("nan") if undefined else mean(diffs) / spread

def paired_stats(x: List[float], y: List[float], alternative: str = "greater"):
    """
    Run a paired t-test and a Wilcoxon signed-rank test on *x* vs *y*.

    The t-test (ttest_rel) is always two-sided. *alternative* ("two-sided", "greater"
    or "less", describing x relative to y) is passed only to wilcoxon.

    Returns:
        (t_stat, t_p, w_stat, w_p), or (None, None, None, None) when SciPy is unavailable.

    Degenerate inputs (every paired difference identical) are answered directly instead
    of being passed to SciPy:
      * all differences zero: (0.0, 1.0, 0.0, 1.0).
      * constant shift d != 0: t_stat is +inf or -inf with the sign of d, and t_p is 0.0.
        w_stat is the rank sum SciPy reports: the positive-rank sum for one-sided tests,
        the smaller of the two rank sums for two-sided. w_p is 0.0 when the shift lies
        in the direction being tested (always, for two-sided) and 1.0 when it points the
        other way.

    Raises:
        ValueError: if *alternative* is not one of the three accepted values.
    """
    if not SCIPY_OK:
        return (None, None, None, None)
    if alternative not in ("two-sided", "greater", "less"):
        raise ValueError(
            f"alternative must be 'two-sided', 'greater' or 'less', not {alternative!r}"
        )

    diffs = [xi - yi for xi, yi in zip(x, y)]

    if len(set(diffs)) <= 1:
        # Perfect equality (or empty input): no difference at all.
        if all(d == 0 for d in diffs):
            return (0.0, 1.0, 0.0, 1.0)
        # Constant non-zero shift: zero variance makes t infinite, carrying the shift's sign.
        shift = diffs[0]
        t_stat = math.copysign(float("inf"), shift)
        # Every |d| is tied, so the positive-rank sum is either all ranks, n(n+1)/2, or none.
        n = len(diffs)
        total_rank = n * (n + 1) / 2
        r_plus = total_rank if shift > 0 else 0.0
        if alternative == "two-sided":
            return (t_stat, 0.0, min(r_plus, total_rank - r_plus), 0.0)
        # One-sided: the previous code returned p=0.0 even when the shift contradicts H1.
        supports = shift > 0 if alternative == "greater" else shift < 0
        return (t_stat, 0.0, r_plus, 0.0 if supports else 1.0)

    t_res = ttest_rel(x, y)
    w_res = wilcoxon(x, y, alternative=alternative)
    return (float(t_res.statistic), float(t_res.pvalue),
            float(w_res.statistic), float(w_res.pvalue))

@dataclass
class RunGroup:
    """A named set of run-log files that are loaded and analysed together."""
    # Group label; main() uses "Cold", "Warm" or "Alt".
    name: str
    # Paths of the run logs in this group, in the order main() built them from the tags.
    files: List[Path]

def build_filename(condition: str, hardware: str, tag: str) -> str:
    """Return the log filename for one run, e.g. HighRigor_GPU_T0_W2.txt."""
    filename = f"{condition}_{hardware}_T0_{tag}.txt"
    return filename

def load_run(path: Path) -> Dict[int, str]:
    """
    Read one run log and parse it into its 50 numbered items.

    The file is decoded as UTF-8 with errors="ignore", so undecodable bytes are dropped
    silently rather than raising.

    Raises:
        ValueError: if the file does not parse into exactly items 1..50. The message is
            prefixed with *path* so the offending log can be identified, and the original
            error is chained.
    """
    text = path.read_text(encoding="utf-8", errors="ignore")
    try:
        return extract_items(text)
    except ValueError as err:
        raise ValueError(f"{path}: {err}") from err

def load_runs(paths: List[Path]) -> List[Dict[int, str]]:
    """Parse every run log in *paths* with load_run, preserving their order."""
    return list(map(load_run, paths))

def within_identical_items(runs: List[Dict[int, str]]) -> int:
    """Count items 1..50 whose whitespace-normalized text is identical in every run."""
    return sum(
        1
        for item in range(1, 51)
        if len({normalize(run[item]) for run in runs}) == 1
    )

def consensus_text(runs: List[Dict[int, str]], i: int) -> str:
    """
    Return the most frequent normalized text for item *i* across *runs*.

    Ties, including the case where every run differs, go to the text seen first in run
    order, because Counter.most_common keeps first-encountered order among equal counts.
    """
    votes = Counter(normalize(run[i]) for run in runs)
    ranked = votes.most_common(1)
    winner, _count = ranked[0]
    return winner

def group_consensus_items(runs: List[Dict[int, str]]) -> Dict[int, str]:
    """Map each item number 1..50 to its consensus text across *runs* (via consensus_text)."""
    consensus: Dict[int, str] = {}
    for item in range(1, 51):
        consensus[item] = consensus_text(runs, item)
    return consensus

def exact_match_count(a_items: Dict[int, str], b_items: Dict[int, str]) -> int:
    """Count items 1..50 whose text is exactly equal in both mappings (no normalization here)."""
    matches = 0
    for item in range(1, 51):
        if a_items[item] == b_items[item]:
            matches += 1
    return matches

def mean_wordcount_per_item(runs: List[Dict[int, str]]) -> List[float]:
    """Return, for items 1..50 in order, that item's mean word count across *runs*."""
    return [mean([word_count(run[item]) for run in runs]) for item in range(1, 51)]

def ensure_files_exist(paths: List[Path]):
    """
    Raise FileNotFoundError listing every path in *paths* that does not exist.

    Returns None when all paths exist (including when *paths* is empty).
    """
    absent = [p for p in paths if not p.exists()]
    if not absent:
        return
    listing = "\n  ".join(str(p) for p in absent)
    raise FileNotFoundError("Missing files:\n  " + listing)

def _summary_row(condition: str, comparison: str, metric: str, value) -> Dict[str, object]:
    """Build one row of the summary CSV (columns: condition, comparison, metric, value)."""
    return {"condition": condition, "comparison": comparison, "metric": metric, "value": value}


def _load_condition(
    base_dir: Path,
    condition: str,
    cold_tag: str,
    warm_tags: List[str],
    alt_tags: List[str],
) -> Dict[str, Dict[str, dict]]:
    """
    Load every run log for one condition and compute per-hardware, per-group statistics.

    Returns:
        per_hw[hardware][group] -> {"runs": parsed runs, "within": identical-item count,
        "cons": item -> consensus text, "wc": per-item mean word counts}, for each
        hardware in HARDWARES and each group in Cold/Warm/Alt.

    Raises:
        FileNotFoundError: if any expected log is missing. All files for the condition,
            for both hardwares, are checked before any file is parsed.
    """
    groups: Dict[Tuple[str, str], RunGroup] = {}
    for hw in HARDWARES:
        cold_files = [base_dir / build_filename(condition, hw, cold_tag)]
        warm_files = [base_dir / build_filename(condition, hw, t) for t in warm_tags]
        alt_files = [base_dir / build_filename(condition, hw, t) for t in alt_tags]
        ensure_files_exist(cold_files + warm_files + alt_files)
        groups[(hw, "Cold")] = RunGroup("Cold", cold_files)
        groups[(hw, "Warm")] = RunGroup("Warm", warm_files)
        groups[(hw, "Alt")] = RunGroup("Alt", alt_files)

    per_hw: Dict[str, Dict[str, dict]] = {}
    for hw in HARDWARES:
        per_hw[hw] = {}
        for gname in ("Cold", "Warm", "Alt"):
            runs = load_runs(groups[(hw, gname)].files)
            if len(runs) > 1:
                within = within_identical_items(runs)
                cons = group_consensus_items(runs)
            else:
                # A single run (always the case for Cold) has no within-group determinism
                # to measure. 50 is a placeholder, and the run's own normalized text
                # serves as its "consensus".
                within = 50
                cons = {i: normalize(runs[0][i]) for i in range(1, 51)}
            per_hw[hw][gname] = {
                "runs": runs,
                "within": within,
                "cons": cons,
                "wc": mean_wordcount_per_item(runs),
            }
    return per_hw


def _wordcount_rows(
    condition: str,
    comparison: str,
    diff_metric: str,
    x_wc: List[float],
    y_wc: List[float],
) -> List[Dict[str, object]]:
    """
    Build summary rows comparing two groups' per-item mean word counts, paired over items.

    Emits the mean of (x - y) under *diff_metric* and Cohen's d_z. When SciPy is
    available, it also emits two-sided paired-t and Wilcoxon p-values.
    """
    diffs = [x - y for x, y in zip(x_wc, y_wc)]
    rows = [
        _summary_row(condition, comparison, diff_metric, mean(diffs)),
        _summary_row(condition, comparison, "wordcount_cohens_dz", cohens_dz(diffs)),
    ]
    if SCIPY_OK:
        _t_stat, t_p, _w_stat, w_p = paired_stats(x_wc, y_wc, alternative="two-sided")
        rows.append(_summary_row(condition, comparison, "paired_t_p", t_p))
        rows.append(_summary_row(condition, comparison, "wilcoxon_p_two_sided", w_p))
    return rows


def _write_csv(path: Path, fieldnames: List[str], rows: List[Dict[str, object]]) -> None:
    """Write *rows* to *path* as a UTF-8 CSV with a header line."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main():
    """
    Command-line entry point.

    For each condition, loads the Cold/Warm/Alt run logs for both hardwares from
    --base_dir and computes the statistics described in the module docstring. It then
    writes <out_prefix>_summary.csv and <out_prefix>_itemwise_divergence.csv into
    --base_dir.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_dir", type=str, required=True, help="Directory containing the .txt run logs")
    ap.add_argument("--warm_tags", nargs="+", default=["W2", "W3", "W4"], help="Warm run tags (default: W2 W3 W4)")
    ap.add_argument("--alt_tags", nargs="+", default=["A1", "A2", "A3"], help="Alt-env run tags (default: A1 A2 A3)")
    ap.add_argument("--cold_tag", default="C1", help="Cold run tag (default: C1)")
    ap.add_argument("--out_prefix", default="results", help="Output prefix for CSVs")
    ap.add_argument("--primary_group", choices=["Warm", "Cold", "Alt"], default="Warm", help="Primary group for CPU vs GPU comparisons")
    args = ap.parse_args()

    base_dir = Path(args.base_dir)

    # --primary_group was previously parsed but ignored. Its CPU-vs-GPU comparison is now
    # reported first; the default ("Warm") reproduces the original Warm, Cold, Alt order.
    group_names = ["Warm", "Cold", "Alt"]
    cpu_gpu_order = [args.primary_group] + [g for g in group_names if g != args.primary_group]

    # (comparison label, left group, right group, suffix of the word-count diff metric)
    pair_specs = [
        ("Cold_vs_Warm", "Cold", "Warm", "c_minus_w"),
        ("Alt_vs_Warm", "Alt", "Warm", "a_minus_w"),
        ("Cold_vs_Alt", "Cold", "Alt", "c_minus_a"),
    ]

    summary_rows: List[Dict[str, object]] = []
    itemwise_rows: List[Dict[str, object]] = []

    for condition in CONDITIONS:
        per_hw = _load_condition(base_dir, condition, args.cold_tag, args.warm_tags, args.alt_tags)

        # Within-hardware determinism (only meaningful for multi-run groups: Warm and Alt).
        for hw in HARDWARES:
            for gname in ("Warm", "Alt"):
                summary_rows.append(_summary_row(
                    condition, f"{hw}_{gname}_within", "within_identical_items",
                    per_hw[hw][gname]["within"],
                ))

        # Cold/Warm/Alt divergences within hardware (exact-match counts of consensus texts).
        for hw in HARDWARES:
            cons = {g: per_hw[hw][g]["cons"] for g in ("Cold", "Warm", "Alt")}
            for label, left, right, _suffix in pair_specs:
                summary_rows.append(_summary_row(
                    condition, f"{hw}_{label}", "exact_match_items",
                    exact_match_count(cons[left], cons[right]),
                ))
            for i in range(1, 51):
                itemwise_rows.append({
                    "condition": condition,
                    "hardware": hw,
                    "item": i,
                    "cold_eq_warm": int(cons["Cold"][i] == cons["Warm"][i]),
                    "alt_eq_warm": int(cons["Alt"][i] == cons["Warm"][i]),
                    "cold_eq_alt": int(cons["Cold"][i] == cons["Alt"][i]),
                })

        # Word-count paired stats between groups within each hardware (paired over 50 items).
        for hw in HARDWARES:
            wc = {g: per_hw[hw][g]["wc"] for g in ("Cold", "Warm", "Alt")}
            for label, left, right, suffix in pair_specs:
                summary_rows += _wordcount_rows(
                    condition, f"{hw}_{label}", f"wordcount_mean_diff_{suffix}",
                    wc[left], wc[right],
                )

        # CPU vs GPU comparisons per group; the Wilcoxon test is one-sided (CPU > GPU).
        for gname in cpu_gpu_order:
            comparison = f"CPU_vs_GPU_{gname}"
            cpu = per_hw["CPU"][gname]
            gpu = per_hw["GPU"][gname]
            diffs = [c - g for c, g in zip(cpu["wc"], gpu["wc"])]
            summary_rows += [
                _summary_row(condition, comparison, "exact_match_items", exact_match_count(cpu["cons"], gpu["cons"])),
                _summary_row(condition, comparison, "wordcount_mean_cpu", mean(cpu["wc"])),
                _summary_row(condition, comparison, "wordcount_mean_gpu", mean(gpu["wc"])),
                _summary_row(condition, comparison, "wordcount_mean_diff_cpu_minus_gpu", mean(diffs)),
                _summary_row(condition, comparison, "wordcount_sd_diff_cpu_minus_gpu", stdev(diffs)),
                _summary_row(condition, comparison, "cohens_dz", cohens_dz(diffs)),
            ]
            if SCIPY_OK:
                _t_stat, t_p, _w_stat, w_p = paired_stats(cpu["wc"], gpu["wc"], alternative="greater")
                summary_rows.append(_summary_row(condition, comparison, "paired_t_p", t_p))
                summary_rows.append(_summary_row(condition, comparison, "wilcoxon_p_one_sided_cpu_gt_gpu", w_p))

    # Write CSVs
    out_summary = base_dir / f"{args.out_prefix}_summary.csv"
    out_itemwise = base_dir / f"{args.out_prefix}_itemwise_divergence.csv"

    _write_csv(out_summary, ["condition", "comparison", "metric", "value"], summary_rows)
    _write_csv(
        out_itemwise,
        ["condition", "hardware", "item", "cold_eq_warm", "alt_eq_warm", "cold_eq_alt"],
        itemwise_rows,
    )

    print(f"Wrote: {out_summary}")
    print(f"Wrote: {out_itemwise}")
    if not SCIPY_OK:
        print("Note: SciPy not available; p-values not computed. Install scipy to enable tests.")

if __name__ == "__main__":
    main()
